{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX Script chunk\n",
    "\n",
    "下面的例子是直接从新的 PyTorch ONNX 导出器改编而来，实现了对 {func}`torch.chunk` 的支持，该函数尝试将张量分割成指定数量的块。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Sequence\n",
    "from onnxscript import opset18 as op, script, FLOAT, INT64\n",
    "\n",
    "@script()\n",
    "def aten_chunk(\n",
    "    tensor: FLOAT[...], chunks: int, dim: int = 0,\n",
    ") -> Sequence[FLOAT[...]]:\n",
    "    neg_1 = op.Constant(value_ints=[-1])\n",
    "\n",
    "    # Get size of specified dim\n",
    "    dim_size = op.Shape(tensor)[dim]\n",
    "\n",
    "    # Compute size/chunk to get the number of data in one chunk\n",
    "    num_per_chunk = dim_size / chunks + op.Cast(dim_size % chunks > 0, to=INT64.dtype)\n",
    "\n",
    "    # Compute real chunk number\n",
    "    num_chunk = dim_size / num_per_chunk\n",
    "\n",
    "    # Get something like [n, n, n, n, ...], total num_chunk\n",
    "    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))\n",
    "\n",
    "    remainder = dim_size % num_per_chunk\n",
    "    if remainder > 0:\n",
    "        # Append the remainder to the [n, n, n, n, ..., r]\n",
    "        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)\n",
    "\n",
    "    return op.SplitToSequence(tensor, list_split, axis=dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们从 onnxscript 导入我们想要使用的 ONNX opset（在这个例子中是版本18）、`@script` 装饰器，以及 FLOAT 和 INT64 的张量类型。在 ONNX Script 中，张量形状是通过类型下标表示的，例如 `FLOAT[2, 10]`，或者符号性地表示为 `FLOAT[\"M\", \"N\"]`，或者在张量形状未知的情况下使用 `FLOAT[...]`。如果没有下标（仅 FLOAT），该类型旨在表示标量（秩为 0 的张量）。\n",
    "\n",
    "接下来，我们定义了一个带有类型注解的 `aten_chunk` 函数，并使用内置的 Python 语法和显式的 ONNX 算子调用来实现函数体。这个例子使用了各种二元表达式和一个 `if` 语句，但也支持许多其他的 Python 惯用构造。\n",
    "\n",
    "我们还需要定义一个简单的模型来调用我们的 ONNX Script 函数，以便我们可以导出并验证一个端到端的例子："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@script()\n",
    "def ten_chunks_model(tensor: FLOAT[\"M\"]):\n",
    "    return aten_chunk(tensor, chunks=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个模型将简单地将提供的张量分割成十个张量，但它也展示了 ONNX 函数当然可以调用其他 ONNX 函数，而不仅仅是内置的 ONNX 算子。\n",
    "\n",
    "我们现在将把 ONNX Script 模型导出到 ONNX，并在 [Netron](https://netron.app/) 中探索它。使用 `@script` 装饰的函数允许它们使用 `to_model_proto` 函数进行导出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "onnx.save_model(\n",
    "    ten_chunks_model.to_model_proto(),\n",
    "    \"ten_chunks_model.onnx\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](images/ten_chunks_model.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "图表展示了我们的两个 ONNX 函数；我们可以观察到原始的输入张量从 `ten_chunks_model` 流入 `aten_chunk`，以及属性 chunks=10。返回的是一系列最多包含 10 个张量的序列。正如人们所期望的，ONNX 中的函数可以定义一次，并在模型中任意多次调用。*[阅读更多关于核心 ONNX 概念的信息](https://onnx.ai/onnx/intro/concepts.html#input-output-node-initializer-attributes)。*"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
